{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in lasagne\n",
    "\n",
    "Just as we did for Q-learning, we'll now design a Lasagne network that learns `CartPole-v0` via policy gradient (REINFORCE).\n",
    "\n",
    "Most of the code in this notebook is taken from the approximate Q-learning notebook, so it should look familiar, and in places even simpler."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Frameworks__ - we'll accept this homework in any deep learning framework. For example, it translates to TensorFlow almost line for line. However, we recommend that you stick to Theano/Lasagne unless you're confident in your skills with the framework of your choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: THEANO_FLAGS='floatX=float32'\n"
     ]
    }
   ],
   "source": [
    "%env THEANO_FLAGS = 'floatX=float32'",
    "\n",
    "import os",
    "\n",
    "if not os.environ.get(\"DISPLAY\"):",
    "\n",
    "    !bash ../xvfb start",
    "\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-03-14 19:35:59,320] Making new env: CartPole-v0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f233bdd3fd0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEACAYAAACwB81wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFDZJREFUeJzt3X+MXeWd3/H3B2yLBLrxWqTG2G7tbEAb70ZrU2Gqpgm3\nESGm3UIaVcBus0JdukFim6BU7caOqjLpSpRESjZSKyJoyMqbLk6sZBPBNiQ2lNuQSoEk2Alk7MWW\nmC1DsUkJSaFpdm349o97bC5j47kzc2fGvuf9kkZ+7nPOuef5SuPPPPf8uCdVhSRp9J212AOQJC0M\nA1+SWsLAl6SWMPAlqSUMfElqCQNfklpiXgI/yZYk+5McSPLR+diHJGlmMuzr8JOcDfwFcAXwDPBd\n4Leqat9QdyRJmpH5mOFvBg5W1URVHQG+CFwzD/uRJM3AfAT+auDpvteTTZ8kaRHNR+D7XQ2SdBpa\nMg/v+Qywtu/1Wnqz/OOS+EdBkmahqjLbbedjhv894KIk65IsA64D7p26UlWN7M+tt9666GOwPutr\nY32jXFvV3OfJQ5/hV9XRJP8S+CZwNnB3eYWOJC26+TikQ1XdD9w/H+8tSZod77SdB51OZ7GHMK+s\n78w2yvWNcm3DMPQbrwbaaVKLsV9JOpMloU6zk7aSpNOQgS9JLWHgS1JLGPiS1BIGviS1hIEvSS1h\n4EtSSxj4ktQSBr4ktYSBL0ktYeBLUksY+JLUEga+JLWEgS9JLWHgS1JLGPiS1BIGviS1xJyeaZtk\nAvg/wMvAkaranGQF8CXgbwMTwLVV9dM5jlOSNEdzneEX0KmqTVW1uenbCuyuqouBB5vXkqRFNoxD\nOlOfr3g1sL1pbwfeN4R9SJLmaBgz/AeSfC/J7zV9K6vqcNM+DKyc4z4kSUMwp2P4wDuq6tkkbwZ2\nJ9nfv7CqKknNcR+SpCGYU+BX1bPNvz9O8lVgM3A4yQVVdSjJKuC5k207NjZ2vN3pdOh0OnMZiiSN\nnG63S7fbHdr7pWp2E/AkbwTOrqoXk5wL7AI+DlwBPF9Vn0iyFVheVVunbFuz3a8ktVUSqmrqedPB\nt59D4K8Hvtq8XAL8aVX9h+ayzJ3A3+J1Lss08CVp5hYt8OfCwJekmZtr4HunrSS1hIEvSS1h4EtS\nSxj4ktQSBr4ktYSBL0ktYeBLUksY+JLUEga+JLWEgS9JLWHgS1JLGPiS1BIGviS1hIEvSS1h4EtS\nSxj4ktQSBr4ktYSBL0ktYeBLUktMG/hJPp/kcJLH+/pWJNmd5Mkku5Is71u2LcmBJPuTXDlfA5ck\nzcwgM/w/BrZM6dsK7K6qi4EHm9ck2QBcB2xotrkjiZ8iJOk0MG0YV9XDwAtTuq8Gtjft7cD7mvY1\nwI6qOlJVE8BBYPNwhipJmovZzr5XVtXhpn0YWNm0LwQm+9abBFbPch+SpCGa8+GWqiqgTrXKXPch\nSZq7JbPc7nCSC6rqUJJVwHNN/zPA2r711jR9JxgbGzve7nQ6dDqdWQ5FkkZTt9ul2+0O7f3Sm6BP\ns1KyDrivqt7evP4k8HxVfSLJVmB5VW1tTtreQ++4/WrgAeCtNWUnSaZ2SZKmkYSqymy3n3aGn2QH\ncDlwfpKngX8H3A7sTHIjMAFcC1BV40l2AuPAUeBmk12STg8DzfCHvlNn+JI0Y3Od4XuNvCS1hIEv\nSS1h4EtSSxj4ktQSBr4ktYSBL0ktYeBLUksY+JLUEga+JLWEgS9JLWHgS1JLGPiS1BIGviS1hIEv\nSS1h4EtSSxj4ktQSBr4ktYSBL0ktYeBLUktMG/hJPp/kcJLH+/rGkkwm2dP8XNW3bFuSA0n2J7ly\nvgYuSZqZaR9inuSdwEvAn1TV25u+W4EXq+rTU9bdANwDXAqsBh4ALq6qV6as50PMJWmG5v0h5lX1\nMPDCyfZ9kr5rgB1VdaSqJoCDwObZDk
6SNDxzOYb/oSQ/SHJ3kuVN34XAZN86k/Rm+pKkRTbbwP8s\nsB7YCDwLfOoU63rsRpJOA0tms1FVPXesneRzwH3Ny2eAtX2rrmn6TjA2Nna83el06HQ6sxmKJI2s\nbrdLt9sd2vtNe9IWIMk64L6+k7arqurZpv0R4NKq+u2+k7abefWk7VunnqH1pK0kzdxcT9pOO8NP\nsgO4HDg/ydPArUAnyUZ6h2ueAm4CqKrxJDuBceAocLPJLkmnh4Fm+EPfqTN8SZqxeb8sU5I0Ggx8\nSWoJA1+SWsLAl6SWMPAlqSUMfElqCQNfklrCwJemeOGpxxj/yh9SLx9d7KFIQ+WNV9IU/3vfw/zl\nw//lhP6/88E7F2E00qu88UqSNBADX5JawsCXpJYw8CWpJQx8SWoJA1+SWsLAl6SWMPAlqSUMfGmK\nk910de7ffMsijEQaLgNfGsCyc5cv9hCkOZs28JOsTfJQkh8leSLJh5v+FUl2J3kyya4ky/u22Zbk\nQJL9Sa6czwIkSYMZZIZ/BPhIVf0a8HeB30/yNmArsLuqLgYebF6TZANwHbAB2ALckcRPEpK0yKYN\n4qo6VFV7m/ZLwD5gNXA1sL1ZbTvwvqZ9DbCjqo5U1QRwENg85HFLkmZoRjPvJOuATcAjwMqqOtws\nOgysbNoXApN9m03S+wMhSVpESwZdMcl5wFeAW6rqxeTVb+isqkpyqu87PmHZ2NjY8Xan06HT6Qw6\nFElqhW63S7fbHdr7DfR9+EmWAn8O3F9Vn2n69gOdqjqUZBXwUFX9apKtAFV1e7PeN4Bbq+qRvvfz\n+/B12vr+XTed0PfL6y/hLe85sV9aSPP+ffjpTeXvBsaPhX3jXuCGpn0D8LW+/uuTLEuyHrgIeHS2\nA5QkDccgh3TeAXwA+GGSPU3fNuB2YGeSG4EJ4FqAqhpPshMYB44CNzudl6TFN23gV9W3ef1PAle8\nzja3AbfNYVySpCHz+nhJagkDX5JawsCXpJYw8KU+f/mtL5y030syNQoMfElqCQNfklrCwJekljDw\nJaklDHxJagkDX5JawsCXpJYw8CWpJQx8SWoJA1+SWsLAl6SWMPAlqSUMfKnPTw5+94S+CzZdtQgj\nkYbPwJf6vHL0r07oW7LsDYswEmn4DHxpGj6QWaNi2sBPsjbJQ0l+lOSJJB9u+seSTCbZ0/xc1bfN\ntiQHkuxPcuV8FiDNtyz2AKQhmfYh5sAR4CNVtTfJecD3k+ymN/H5dFV9un/lJBuA64ANwGrggSQX\nV9UrQx67tCCc4WtUTDvDr6pDVbW3ab8E7KMX5HDyyc81wI6qOlJVE8BBYPNwhistPGf4GhUzOoaf\nZB2wCfhO0/WhJD9IcneS5U3fhcBk32aTvPoHQjrjOMPXqBjkkA4AzeGcLwO3VNVLST4L/Ptm8R8C\nnwJufJ3NT/g/MzY2drzd6XTodDqDDkWSWqHb7dLtdof2fqmafv6SZCnw58D9VfWZkyxfB9xXVW9P\nshWgqm5vln0DuLWqHulbvwbZr7TQvn/XiQ8rX3PZ+1n5G+9dhNFIr5WEqpr1UcZBrtIJcDcw3h/2\nSVb1rfZPgMeb9r3A9UmWJVkPXAQ8OtsBSovNqYlGxSCHdN4BfAD4YZI9Td/HgN9KspHe/4engJsA\nqmo8yU5gHDgK3Ox0XmcyT9pqVEwb+FX1bU7+SeD+U2xzG3DbHMYlnTacrWhUeKetNA1n+BoVBr7U\nOPr/Xjxp/1lLz1ngkUjzw8CXGk/+1z86af+bN1y+wCOR5oeBL0ktYeBLUksY+JLUEga+JLWEgS9J\nLWHgS1JLGPiS1BIGviS1hIEvSS1h4EtSSxj4ktQSAz3xaug79YlXWgATExPs2bNn+hUbq174Hyw9\neuIXqP3PN28ZaPt169axadOmgfcnzdRcn3g18DNtpTPNrl27uOmmEx9ZeDLnLFvCt//j757Q/9OX\nfs
H7b3r/QO/xwQ9+kDvvvHNGY5QWkoEv9XnguX92vP3uN3+RP7hz9yKORhouj+FLjf/+43/KL14+\n9/jP1w/dyEtHfnmxhyUNjYEvAS/XEl48uuKE/r8uH36i0XHKwE9yTpJHkuxN8kSSsaZ/RZLdSZ5M\nsivJ8r5ttiU5kGR/kivnefzSUITi7Bxd7GFI8+qUgV9VvwD+QVVtBDYCW5JcBmwFdlfVxcCDzWuS\nbACuAzYAW4A7kvgpQqe9s/Iy5y154YS+FcueXaQRScM37Unbqvp501wGLAUKuBo49ty37UCXXuhf\nA+yoqiPARJKDwGbgO8MdtjRcf3XkZe747Ad48egK3r3pLfyLf3QJv7T0J3xtsQcmDdG0gd/M0B8D\nfgX4T1X1aJKVVXW4WeUwsLJpX8hrw30SWH2y9x30cjlptvbt2zfwulXFgcnngee55/4D3HP/N2e8\nv29961v+Xuu0NsgM/xVgY5I3AV9N8utTlleSU91FddJlq1atOt7udDp0Op2BBiwN6q677uLhhx9e\nsP29613v8jp8DVW326Xb7Q7t/Qa+Dr+qfpbkIeC9wOEkF1TVoSSrgOea1Z4B1vZttqbpO8HY2Njs\nRixJLTF1Mvzxj398Tu833VU65x+7AifJG4D3APuAe4EbmtVugOOHOu8Frk+yLMl64CLg0TmNUJI0\nFNPN8FcB25OcTe+Pw5eq6utJvgPsTHIjMAFcC1BV40l2AuPAUeBmvzRHkk4Ppwz8qnocuOQk/T8B\nrnidbW4DbhvK6CRJQ+M18pLUEga+JLWEgS9JLeEDUDSynnrqKR577LEF29/69eu55JITTnlJQzPX\nB6AY+JJ0hphr4HtIR5JawsCXpJYw8CWpJQx8SWoJA1+SWsLAl6SWMPAlqSUMfElqCQNfklrCwJek\nljDwJaklDHxJagkDX5JaYrqHmJ+T5JEke5M8kWSs6R9LMplkT/NzVd8225IcSLI/yZXzPH5J0oCm\n/XrkJG+sqp8nWQJ8G7gF2AK8WFWfnrLuBuAe4FJgNfAAcHFVvTJlPb8eWZJmaN6/Hrmqft40lwFL\ngWNJfbKdXgPsqKojVTUBHAQ2z3ZwkqThmTbwk5yVZC9wGNhVVY82iz6U5AdJ7k6yvOm7EJjs23yS\n3kxfkrTIBpnhv1JVG4E1wGVJfg34LLAe2Ag8C3zqVG8xjIFKkuZmyaArVtXPkjwEbKmq4wGf5HPA\nfc3LZ4C1fZutafpOMDY2drzd6XTodDoDD1qS2qDb7dLtdof2fqc8aZvkfOBoVf00yRuAbwK3A49V\n1aFmnY8Al1bVb/edtN3Mqydt3zr1DK0nbSVp5uZ60na6Gf4qYHuSs+kd/vlSVX09yZ8k2UjvcM1T\nwE0AVTWeZCcwDhwFbjbZJen0MO1lmfOyU2f4kjRj835ZpiRpNBj4ktQSBr4ktYSBL0ktYeBLUksY\n+JLUEga+JLWEgS9JLWHgS1JLGPiS1BIGviS1hIEvSS1h4EtSSxj4ktQSBr4ktYSBL0ktYeBLUksY\n+JLUEga+JLXEQIGf5Owke5Lc17xekWR3kieT7EqyvG/dbUkOJNmf5Mr5GrgkaWYGneHfAowDx548\nvhXYXVUXAw82r0myAbgO2ABsAe5I0rpPEd1ud7GHMK+s78w2yvWNcm3DMG0YJ1kD/EPgc8Cxp6Vf\nDWxv2tuB9zXta4AdVXWkqiaAg8DmYQ74TDDqv3TWd2Yb5fpGubZhGGT2/UfAvwFe6etbWVWHm/Zh\nYGXTvhCY7FtvElg910FKkubulIGf5DeB56pqD6/O7l+jqopXD/WcdJXZD0+SNCzp5fXrLExuA34H\nOAqcA/wS8GfApUCnqg4lWQU8VFW/mmQrQFXd3mz/DeDWqnpkyvv6R0CSZqGqTjr5HsQpA/81KyaX\nA/+6qv5xkk8Cz1fVJ5qQX15VW5uTtvfQO26/GngAeGsNuhNJ0rxZ
MsP1jwX37cDOJDcCE8C1AFU1\nnmQnvSt6jgI3G/aSdHoYeIYvSTqzLfg18km2NDdlHUjy0YXe/zAk+XySw0ke7+sbiZvRkqxN8lCS\nHyV5IsmHm/5Rqe+cJI8k2dvUN9b0j0R9x4zyzZJJJpL8sKnv0aZvJOpLsjzJl5PsSzKe5LKh1lZV\nC/YDnE3v2vx1wFJgL/C2hRzDkOp4J7AJeLyv75PAHzTtjwK3N+0NTZ1Lm7oPAmctdg2nqO0CYGPT\nPg/4C+Bto1JfM+Y3Nv8uAb4DXDZK9TXj/lfAnwL3jtLvZzPmp4AVU/pGoj569zX9bt/v55uGWdtC\nz/A3AweraqKqjgBfpHez1hmlqh4GXpjSPRI3o1XVoara27RfAvbROwE/EvUBVNXPm+Yyev9ZihGq\nryU3S069UuWMry/Jm4B3VtXnAarqaFX9jCHWttCBvxp4uu/1KN2YNXI3oyVZR++TzCOMUH1Jzkqy\nl14du6rqUUaoPkb/ZskCHkjyvSS/1/SNQn3rgR8n+eMkjyX5z0nOZYi1LXTgt+IMcfU+b53RN6Ml\nOQ/4CnBLVb3Yv+xMr6+qXqmqjcAa4LIkvz5l+RlbX0tulnxHVW0CrgJ+P8k7+xeewfUtAS4B7qiq\nS4D/S/M9ZcfMtbaFDvxngLV9r9fy2r9QZ7LDSS4AaG5Ge67pn1rzmqbvtJVkKb2w/0JVfa3pHpn6\njmk+Lj8EvJfRqe/vAVcneQrYAbw7yRcYnfqoqmebf38MfJXeYYxRqG8SmKyq7zavv0zvD8ChYdW2\n0IH/PeCiJOuSLKP3zZr3LvAY5su9wA1N+wbga3391ydZlmQ9cBHw6CKMbyBJAtwNjFfVZ/oWjUp9\n5x+7yiHJG4D30DtPMRL1VdXHqmptVa0Hrgf+W1X9DiNSX5I3JvkbTftc4ErgcUagvqo6BDyd5OKm\n6wrgR8B9DKu2RTgLfRW9Kz8OAtsW+6z4LGvYAfwv4K/pnZP458AKencWPwnsonf38bH1P9bUux94\n72KPf5ra/j69Y797gT3Nz5YRqu/twGPAD+gFxb9t+keivim1Xs6rV+mMRH30jnPvbX6eOJYhI1Tf\nbwDfbX4//4zeVTpDq80brySpJVr3cBJJaisDX5JawsCXpJYw8CWpJQx8SWoJA1+SWsLAl6SWMPAl\nqSX+P7qsoBUsM5jGAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f233ea3cfd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym",
    "\n",
    "import numpy as np",
    "\n",
    "import pandas as pd",
    "\n",
    "import matplotlib.pyplot as plt",
    "\n",
    "%matplotlib inline",
    "\n",
    "\n",
    "env = gym.make(\"CartPole-v0\").env",
    "\n",
    "env.reset()",
    "\n",
    "n_actions = env.action_space.n",
    "\n",
    "state_dim = env.observation_space.shape",
    "\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the REINFORCE algorithm, we'll need a model that predicts action probabilities given states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import theano",
    "\n",
    "import theano.tensor as T",
    "\n",
    "\n",
    "# create input variables. We'll support multiple states at once",
    "\n",
    "\n",
    "states = T.matrix(\"states[batch,units]\")",
    "\n",
    "actions = T.ivector(\"action_ids[batch]\")",
    "\n",
    "cumulative_rewards = T.vector(\"G[batch] = r + gamma*r' + gamma^2*r'' + ...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lasagne",
    "\n",
    "from lasagne.layers import *",
    "\n",
    "\n",
    "# input layer",
    "\n",
    "l_states = InputLayer((None,)+state_dim, input_var=states)",
    "\n",
    "\n",
    "\n",
    "<Your architecture. Please start with 1-2 layers of 50-200 neurons each >",
    "\n",
    "\n",
    "# output layer",
    "\n",
    "# this time we need to predict action probabilities,",
    "\n",
    "# so make sure your nonlinearity forces p>0 and sum_p = 1",
    "\n",
    "l_action_probas = DenseLayer( < ... > ,",
    "\n",
    "                             num_units= < ... > ,",
    "\n",
    "                             nonlinearity= < ... > )"
   ]
  },
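  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common nonlinearity that satisfies both constraints is the softmax (named here as an illustration, not as the required answer). Writing $ z_a(s) $ for the pre-activation output of the last layer for action $ a $ (our notation, not defined elsewhere in this notebook), it maps any real-valued vector to a valid distribution:\n",
    "\n",
    "$$ \\pi_\\theta (a | s) = { \\exp z_a(s) \\over \\sum _{a'} \\exp z_{a'}(s) } $$\n",
    "\n",
    "Every term is positive because $ \\exp $ is positive, and the terms sum to 1 because each is divided by their total."
   ]
  },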
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get probabilities of actions",
    "\n",
    "predicted_probas = get_output(l_action_probas)",
    "\n",
    "\n",
    "# predict action probability given state",
    "\n",
    "# if you use float32, set allow_input_downcast=True",
    "\n",
    "predict_proba = <compile a function that takes states and returns predicted_probas >"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define the objective and the weight update that follows the policy gradient.\n",
    "\n",
    "Our objective function is the expected return, estimated from the sampled state-action pairs:\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum  _{s_i,a_i} G(s_i,a_i) $$\n",
    "\n",
    "\n",
    "Following the REINFORCE algorithm, we can define a surrogate objective as follows: \n",
    "\n",
    "$$ \\hat J \\approx { 1 \\over N } \\sum  _{s_i,a_i} \\log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "When you compute the gradient of that function with respect to the network weights $ \\theta $, treating $ G $ as a constant, you get exactly the policy gradient estimate.\n"
   ]
  },
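  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of why the log-probability form yields the policy gradient, take a single state $ s $ and treat $ G(s,a) $ as a constant that does not depend on $ \\theta $ (an assumption of this sketch; the full argument is the policy gradient theorem). Using $ \\nabla_\\theta \\pi_\\theta = \\pi_\\theta \\nabla_\\theta \\log \\pi_\\theta $:\n",
    "\n",
    "$$ \\nabla_\\theta \\sum _a \\pi_\\theta (a | s) \\cdot G(s,a) = \\sum _a \\pi_\\theta (a | s) \\cdot \\nabla_\\theta \\log \\pi_\\theta (a | s) \\cdot G(s,a) $$\n",
    "\n",
    "The right-hand side is an expectation over actions drawn from $ \\pi_\\theta $, so averaging over the $ N $ sampled pairs estimates it:\n",
    "\n",
    "$$ \\nabla_\\theta J \\approx { 1 \\over N } \\sum _{s_i,a_i} \\nabla_\\theta \\log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) = \\nabla_\\theta \\hat J $$"
   ]
  },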
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select probabilities for chosen actions, pi(a_i|s_i)",
    "\n",
    "predicted_probas_for_actions = predicted_probas[T.arange(",
    "\n",
    "    actions.shape[0]), actions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# REINFORCE objective function",
    "\n",
    "J =  # <policy objective as in the last formula. Please use mean, not sum.>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all network weights",
    "\n",
    "all_weights = <get all \"thetas\" aka network weights using lasagne >",
    "\n",
    "\n",
    "# weight updates. maximize J = minimize -J",
    "\n",
    "updates = lasagne.updates.sgd(-J, all_weights, learning_rate=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_step = theano.function([states, actions, cumulative_rewards], updates=updates,",
    "\n",
    "                             allow_input_downcast=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def get_cumulative_rewards(rewards,  # rewards at each step",
    "\n",
    "                           gamma=0.99  # discount for reward",
    "\n",
    "                           ):",
    "\n",
    "    \"\"\"",
    "\n",
    "    Take a list of immediate rewards r(s,a) for the whole session and",
    "\n",
    "    compute cumulative returns (a.k.a. G(s,a) in Sutton '16):",
    "\n",
    "    G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...",
    "\n",
    "\n",
    "    A simple way to compute cumulative rewards is to iterate from the last time step to the first",
    "\n",
    "    and compute G_t = r_t + gamma*G_{t+1} recursively.",
    "\n",
    "\n",
    "    You must return an array/list of cumulative rewards with as many elements as the input rewards.",
    "\n",
    "    \"\"\"",
    "\n",
    "\n",
    "    <your code here >",
    "\n",
    "\n",
    "    return < array of cumulative rewards >"
   ]
  },
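  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the backward recurrence $ G_t = r_t + \\gamma G_{t+1} $, take rewards $ [0, 0, 1, 0, 0, 1, 0] $ with $ \\gamma = 0.9 $ (the first numeric test case in the next cell) and assume the return after the final step is zero. Going from the last step to the first:\n",
    "\n",
    "$$ G_6 = 0, \\quad G_5 = 1 + 0.9 \\cdot 0 = 1, \\quad G_4 = 0 + 0.9 \\cdot 1 = 0.9, \\quad G_3 = 0 + 0.9 \\cdot 0.9 = 0.81 $$\n",
    "\n",
    "$$ G_2 = 1 + 0.9 \\cdot 0.81 = 1.729, \\quad G_1 = 0 + 0.9 \\cdot 1.729 = 1.5561, \\quad G_0 = 0 + 0.9 \\cdot 1.5561 = 1.40049 $$\n",
    "\n",
    "Each return reuses the one computed just before it, so the whole list takes a single pass over the rewards."
   ]
  },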
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(get_cumulative_rewards(range(100))) == 100",
    "\n",
    "assert np.allclose(get_cumulative_rewards([0, 0, 1, 0, 0, 1, 0], gamma=0.9), [",
    "\n",
    "                   1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])",
    "\n",
    "assert np.allclose(get_cumulative_rewards(",
    "\n",
    "    [0, 0, 1, -2, 3, -4, 0], gamma=0.5), [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])",
    "\n",
    "assert np.allclose(get_cumulative_rewards(",
    "\n",
    "    [0, 0, 1, 2, 3, 4, 0], gamma=0), [0, 0, 1, 2, 3, 4, 0])",
    "\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Playing the game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_session(t_max=1000):",
    "\n",
    "    \"\"\"Play the env with the REINFORCE agent and train at the end of the session.\"\"\"",
    "\n",
    "\n",
    "    # arrays to record session",
    "\n",
    "    states, actions, rewards = [], [], []",
    "\n",
    "\n",
    "    s = env.reset()",
    "\n",
    "\n",
    "    for t in range(t_max):",
    "\n",
    "\n",
    "        # action probabilities array aka pi(a|s)",
    "\n",
    "        action_probas = predict_proba([s])[0]",
    "\n",
    "\n",
    "        a = <sample action with given probabilities >",
    "\n",
    "\n",
    "        new_s, r, done, info = env.step(a)",
    "\n",
    "\n",
    "        # record session history to train later",
    "\n",
    "        states.append(s)",
    "\n",
    "        actions.append(a)",
    "\n",
    "        rewards.append(r)",
    "\n",
    "\n",
    "        s = new_s",
    "\n",
    "        if done:",
    "\n",
    "            break",
    "\n",
    "\n",
    "    cumulative_rewards = get_cumulative_rewards(rewards)",
    "\n",
    "    train_step(states, actions, cumulative_rewards)",
    "\n",
    "\n",
    "    return sum(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean reward:20.900\n",
      "mean reward:35.860\n",
      "mean reward:50.820\n",
      "mean reward:88.550\n",
      "mean reward:132.080\n",
      "mean reward:165.890\n",
      "mean reward:193.790\n",
      "mean reward:166.510\n",
      "mean reward:120.910\n",
      "mean reward:98.450\n",
      "mean reward:236.340\n",
      "mean reward:280.410\n",
      "mean reward:317.610\n",
      "You Win!\n"
     ]
    }
   ],
   "source": [
    "for i in range(100):",
    "\n",
    "\n",
    "    rewards = [generate_session() for _ in range(100)]  # generate new sessions",
    "\n",
    "\n",
    "    print(\"mean reward:%.3f\" % (np.mean(rewards)))",
    "\n",
    "\n",
    "    if np.mean(rewards) > 300:",
    "\n",
    "        print(\"You Win!\")",
    "\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-03-14 19:36:45,862] Making new env: CartPole-v0\n",
      "[2017-03-14 19:36:45,870] DEPRECATION WARNING: env.spec.timestep_limit has been deprecated. Replace your call to `env.spec.timestep_limit` with `env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')`. This change was made 12/28/2016 and is included in version 0.7.0\n",
      "[2017-03-14 19:36:45,873] Clearing 12 monitor files from previous run (because force=True was provided)\n",
      "[2017-03-14 19:36:45,894] Starting new video recorder writing to /home/jheuristic/Downloads/Practical_RL/week6/videos/openaigym.video.0.7776.video000000.mp4\n",
      "[2017-03-14 19:36:51,516] Starting new video recorder writing to /home/jheuristic/Downloads/Practical_RL/week6/videos/openaigym.video.0.7776.video000001.mp4\n",
      "[2017-03-14 19:36:57,580] Starting new video recorder writing to /home/jheuristic/Downloads/Practical_RL/week6/videos/openaigym.video.0.7776.video000008.mp4\n",
      "[2017-03-14 19:37:05,049] Starting new video recorder writing to /home/jheuristic/Downloads/Practical_RL/week6/videos/openaigym.video.0.7776.video000027.mp4\n",
      "[2017-03-14 19:37:08,785] Starting new video recorder writing to /home/jheuristic/Downloads/Practical_RL/week6/videos/openaigym.video.0.7776.video000064.mp4\n",
      "[2017-03-14 19:37:11,505] Finished writing results. You can upload them to the scoreboard via gym.upload('/home/jheuristic/Downloads/Practical_RL/week6/videos')\n"
     ]
    }
   ],
   "source": [
    "# record sessions",
    "\n",
    "import gym.wrappers",
    "\n",
    "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),",
    "\n",
    "                           directory=\"videos\", force=True)",
    "\n",
    "sessions = [generate_session() for _ in range(100)]",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<video width=\"640\" height=\"480\" controls>\n",
       "  <source src=\"./videos/openaigym.video.0.7776.video000001.mp4\" type=\"video/mp4\">\n",
       "</video>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show video",
    "\n",
    "from IPython.display import HTML",
    "\n",
    "import os",
    "\n",
    "\n",
    "video_names = list(",
    "\n",
    "    filter(lambda s: s.endswith(\".mp4\"), os.listdir(\"./videos/\")))",
    "\n",
    "\n",
    "HTML(\"\"\"",
    "\n",
    "<video width=\"640\" height=\"480\" controls>",
    "\n",
    "  <source src=\"{}\" type=\"video/mp4\">",
    "\n",
    "</video>",
    "\n",
    "\"\"\".format(\"./videos/\"+video_names[-1]))  # os.listdir order is arbitrary, so this may not be the _last_ video; try other indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
